In [2]:
import glob
import os
from typing import Any, Dict, List, Tuple, Union

import torch
import yaml
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder, VisionDataset

from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from model import CustomVGG
import torchvision

In [4]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

import torchvision
from torchvision import datasets, transforms

In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to local .py files are picked up
# without restarting the kernel.
# NOTE(review): this cell carries execution count In [1] but sits below the
# import cells — consider moving it to the top so a fresh top-to-bottom run
# matches the recorded order.
%load_ext autoreload
%autoreload 2

In [5]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Let cuDNN benchmark its convolution algorithms.
cudnn.benchmark = True

## DataLoader

In [6]:
from torch.utils.data import DataLoader, ConcatDataset, WeightedRandomSampler
from dataset import get_dataset, get_weighted_sampler, get_concat_dataset
import torch
import numpy as np
import matplotlib.pyplot as plt

# Loader configuration.
input_size = 128   # forwarded to get_dataset()
batch_size = 128   # mini-batch size used by the DataLoader
n_worker = 8       # number of DataLoader worker processes

train_dataset, valid_dataset = get_dataset(input_size)
concat_dataset = get_concat_dataset()

# Inverse-class-frequency sampling weights: one weight per sample, listed for
# the train targets first and the valid targets second.
# NOTE(review): assumes concat_dataset orders its samples the same way
# (train followed by valid) — confirm against get_concat_dataset().
sample_freq = np.bincount(train_dataset.targets + valid_dataset.targets)
sample_weight = torch.from_numpy(
    np.array([
        1 / sample_freq[label]
        for split in (train_dataset, valid_dataset)
        for label in split.targets
    ])
)

# Each pass over the sampler draws half as many indices as there are weights.
sampler = WeightedRandomSampler(
    weights=sample_weight.type('torch.DoubleTensor'),
    num_samples=len(sample_weight) // 2,
)

#sampler = get_weighted_sampler()

train_loader = DataLoader(
    concat_dataset,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=n_worker,
    drop_last=True,
)

## Model

In [17]:
# Build the unpruned network and restore its trained weights.
model = CustomVGG(bias=True)
model.to(device)

# map_location=device remaps the checkpoint's tensors onto the device chosen
# above, so a checkpoint saved from a GPU run still loads when this notebook
# falls back to the CPU.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load('save/vgg9_final.pt', map_location=device))

# Switch to evaluation mode; as the last expression this also displays the
# module structure.
model.eval()

CustomVGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPoo

## Hook for Rank Calculation

In [8]:
# Indices into model.features whose outputs are hooked for rank statistics.
# Per the printed model above, 3/7/14 are MaxPool2d layers and 10 is a ReLU;
# the remaining indices are presumably the matching activations of the later
# conv blocks — verify against the full model printout.
feature_maps = [3, 7, 10, 14, 17, 21, 24]

# Shared accumulator mutated by get_feature_hook:
#   rank_result[0] -> running per-channel mean rank
#   rank_result[1] -> number of images accumulated so far
rank_result = [torch.tensor(0.) for _ in range(2)]

In [9]:
def get_feature_hook(self, input, output):
    """Forward hook that maintains a running per-channel mean of feature-map ranks.

    For every image ``i`` and channel ``j`` of ``output`` the matrix rank of
    ``output[i, j, :, :]`` is computed. The per-channel sums over the batch are
    folded into the module-level ``rank_result`` list:

    * ``rank_result[0]`` -- per-channel mean rank over all images seen so far
    * ``rank_result[1]`` -- number of images seen so far

    Args:
        self: module the hook is registered on (unused).
        input: inputs passed to that module (unused).
        output: module output, indexed as ``output[i, j, :, :]``, i.e. a 4-D
            tensor whose first two dimensions are image and channel.
    """
    num_images = output.size(0)

    # torch.matrix_rank is deprecated; torch.linalg.matrix_rank is its
    # replacement and treats all leading dimensions as batch dimensions, so one
    # call yields the (num_images, num_channels) rank table instead of one SVD
    # (plus an .item() sync) per image/channel pair.
    ranks = torch.linalg.matrix_rank(output.detach())

    # Sum over the batch and move to the CPU, where rank_result lives.
    ranks = ranks.float().sum(dim=0).cpu()

    # Running mean: undo the previous division, add this batch, re-normalise.
    rank_result[0] = rank_result[0] * rank_result[1] + ranks
    rank_result[1] += num_images
    rank_result[0] = rank_result[0] / rank_result[1]

In [10]:
def inference(limit=5):
    """Push at most ``limit`` mini-batches from ``train_loader`` through ``model``.

    The model is put in eval mode and gradients are disabled. The forward
    outputs are discarded, so the only observable effect comes from whatever
    forward hooks are registered on the model at call time.

    Args:
        limit: maximum number of batches to run.
    """
    model.eval()

    with torch.no_grad():
        for step, batch in enumerate(train_loader):
            images, labels = batch
            if step >= limit:
                break

            images = images.to(device)
            labels = labels.to(device)
            model(images)

## Get Ranks

In [11]:
# Collect the average feature-map rank of every hooked layer and store one
# .npy file per layer. Files are numbered i + 1 so they line up with the conv
# counter load_vgg_model() uses when it reads them back.
ranks_dict = {}

limit = 5               # mini-batches used to estimate the ranks of each layer
rank_dir = 'rank_conv'  # output directory for the rank files

os.makedirs(rank_dir, exist_ok=True)
model.eval()

for i, cov_id in enumerate(feature_maps):

    print(f'feature {i+1}:')

    # Start every layer from clean accumulators, so a previously interrupted
    # run of this cell cannot leak partial statistics into this layer.
    rank_result[0] = torch.tensor(0.)
    rank_result[1] = torch.tensor(0.)

    cov_layer = model.features[cov_id]
    handler = cov_layer.register_forward_hook(get_feature_hook)
    try:
        inference(limit)
    finally:
        # Always detach the hook; a hook left behind after an exception would
        # fire again on the next run and double-count.
        handler.remove()

    ranks_dict[cov_id] = rank_result[0].numpy()
    np.save(os.path.join(rank_dir, 'rank_conv' + str(i + 1) + '.npy'), ranks_dict[cov_id])

    print()

feature 1:

feature 2:

feature 3:

feature 4:

feature 5:

feature 6:

feature 7:



## Prune

In [18]:
# Pruned architecture: every conv keeps 75% of its original filters, except the
# very last conv, which stays at its full 512 filters.
model_pruned = CustomVGG(
    cfg=[
        *(
            [int(width * 0.75) for width in stage]
            for stage in ([64], [128], [256, 256], [512, 512])
        ),
        [int(512 * 0.75), 512],
    ],
    bias=True,
)
model_pruned.eval()

CustomVGG(
  (features): Sequential(
    (0): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d

In [19]:
def _select_channels(tensor, index, dim):
    """Return ``tensor`` restricted to ``index`` along ``dim``; ``None`` keeps every channel."""
    if index is None:
        return tensor
    idx = torch.as_tensor(index, dtype=torch.long, device=tensor.device)
    return tensor.index_select(dim, idx)


def load_vgg_model(model, oristate_dict, rank_prefix='rank_conv/rank_conv'):
    """Initialise a pruned model from the weights of the original model.

    Modules are visited in ``model.named_modules()`` order:

    * ``nn.Conv2d`` -- when the pruned conv has fewer filters than the original,
      the per-filter ranks are read from ``rank_prefix + str(n) + '.npy'``
      (``n`` is the 1-based count of convs seen so far) and the filters with the
      highest ranks are kept, in their original order. The input channels are
      restricted to the filters kept by the previous conv.
    * ``nn.BatchNorm2d`` -- weight, bias, running_mean and running_var are
      restricted to the filters kept by the preceding conv.
    * ``nn.Linear`` -- weight and bias are copied unchanged.

    Args:
        model: pruned model to fill in; updated through ``load_state_dict``.
        oristate_dict: state dict of the original model, using the same
            parameter names as ``model``.
        rank_prefix: path prefix of the rank files.

    Raises:
        ValueError: if a pruned conv has more filters than the original one, or
            a rank file does not hold exactly one value per original filter.
    """
    state_dict = model.state_dict()

    # Output channels of the most recent conv that were kept from the original
    # model; None means "all of them".
    last_select_index = None
    conv_count = 0
    rank_suffix = '.npy'

    for name, module in model.named_modules():

        if isinstance(module, nn.Conv2d):
            conv_count += 1

            weight_key = name + '.weight'
            bias_key = name + '.bias'

            oriweight = oristate_dict[weight_key]
            orifilter_num = oriweight.size(0)
            currentfilter_num = state_dict[weight_key].size(0)

            if currentfilter_num > orifilter_num:
                raise ValueError(
                    f'{name}: pruned layer has {currentfilter_num} filters but '
                    f'the original has only {orifilter_num}')

            if orifilter_num != currentfilter_num:
                rank_path = rank_prefix + str(conv_count) + rank_suffix
                print('loading rank from: ' + rank_path)
                rank = np.load(rank_path)
                if rank.shape != (orifilter_num,):
                    raise ValueError(
                        f'{rank_path}: expected {orifilter_num} rank values, '
                        f'got an array of shape {rank.shape}')
                # Drop the lowest-rank filters; keep the survivors in their
                # original order.
                select_index = np.sort(
                    np.argsort(rank)[orifilter_num - currentfilter_num:])
            else:
                # Every filter of this conv survives.
                select_index = None

            # Output filters follow this layer's selection, input channels
            # follow the previous conv's selection.
            kept_filters = _select_channels(oriweight, select_index, 0)
            state_dict[weight_key] = _select_channels(kept_filters, last_select_index, 1)
            if module.bias is not None:
                state_dict[bias_key] = _select_channels(
                    oristate_dict[bias_key], select_index, 0)

            # Always record this conv's own selection. Keeping the previous
            # layer's selection after an unpruned conv would make the following
            # BatchNorm (and conv inputs) use the wrong channel indices.
            last_select_index = select_index

        elif isinstance(module, nn.BatchNorm2d):
            for suffix in ('.weight', '.bias', '.running_mean', '.running_var'):
                key = name + suffix
                state_dict[key] = _select_channels(
                    oristate_dict[key], last_select_index, 0)

        elif isinstance(module, nn.Linear):
            # NOTE(review): copied as-is, which only matches in shape when the
            # layer feeding the classifier is left unpruned.
            state_dict[name + '.weight'] = oristate_dict[name + '.weight']
            state_dict[name + '.bias'] = oristate_dict[name + '.bias']

    model.load_state_dict(state_dict)

In [20]:
# Transfer the surviving weights of the trained model into the pruned
# architecture, then persist the pruned weights.
load_vgg_model(model=model_pruned, oristate_dict=model.state_dict())
torch.save(obj=model_pruned.state_dict(), f='save/pruned_final.pt')

loading rank from: rank_conv/rank_conv1.npy
loading rank from: rank_conv/rank_conv2.npy
loading rank from: rank_conv/rank_conv3.npy
loading rank from: rank_conv/rank_conv4.npy
loading rank from: rank_conv/rank_conv5.npy
loading rank from: rank_conv/rank_conv6.npy
loading rank from: rank_conv/rank_conv7.npy
