In [None]:
# default_exp misc.fc_decomposer

# Fully-Connected Layers Decomposer

> Factorize heavy FC layers into smaller ones

We can factorize our large fully-connected layers and replace each of them with an approximation made of two smaller layers. The idea is to compute the singular value decomposition (SVD) of the weight matrix, which expresses the original matrix as a product of three matrices: $U \Sigma V^T$,
with $\Sigma$ being a diagonal matrix with non-negative values along its diagonal (the singular values). We then choose a number $k$ of singular values to keep and truncate the matrices $U$ and $V^T$ accordingly. The resulting product is a rank-$k$ approximation of the initial matrix.

![alt text](imgs/svd.pdf "SVD Decomposition")

In [None]:
#all_slow

In [None]:
#hide
from nbdev.showdoc import *

%config InlineBackend.figure_format = 'retina'

In [None]:
#export
from fastai.vision.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

In [None]:
# Fetch the dataset identified by `URLs.PETS` and list the image files found in its `images` sub-folder.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f):
    "Return `True` when the first character of `f` is an uppercase letter."
    first_char = f[0]
    return first_char.isupper()

In [None]:
# Build the dataloaders: each file is labelled by `label_func` applied to its name and resized with `Resize(64)`.
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(64))

In [None]:
# Baseline model: a VGG16-BN with 2 output classes, tracked with the accuracy metric.
learn = Learner(dls, vgg16_bn(num_classes=2), metrics=accuracy)

In [None]:
# Train the baseline for 3 epochs before decomposing its fully-connected layers.
learn.fit_one_cycle(3)

epoch,train_loss,valid_loss,accuracy,time
0,0.886644,0.652151,0.685386,00:22
1,0.692583,0.627857,0.685386,00:21
2,0.646516,0.622866,0.685386,00:22


In [None]:
class FC_Decomposer:
    """Replace the `nn.Linear` layers of a model by a truncated-SVD factorization.

    Each fully-connected layer with weight `W` is replaced by an
    `nn.Sequential` of two smaller `nn.Linear` layers whose product is a
    low-rank approximation of `W`.
    """

    def __init__(self):
        super().__init__()

    def decompose(self, model, percent_removed=0.5):
        """Return a copy of `model` in which every leaf `nn.Linear` is factorized.

        Args:
            model: the module to process; it is left untouched.
            percent_removed: fraction (between 0 and 1) of singular values
                dropped in each fully-connected layer.

        Returns:
            A deep copy of `model` where each leaf `nn.Linear` has been replaced
            by the `nn.Sequential` returned by `SVD`.
        """
        # Copy once at the top level; the traversal below then edits that copy
        # in place instead of deep-copying every sub-tree again while recursing.
        new_model = copy.deepcopy(model)
        self._decompose_inplace(new_model, percent_removed, set())
        return new_model

    def _decompose_inplace(self, module, percent_removed, seen):
        # Replace, in place, every leaf nn.Linear found below `module`.
        # `seen` holds the ids of containers already processed, so a container
        # registered under several names is not decomposed twice.
        if id(module) in seen:
            return
        seen.add(id(module))

        # Iterate over a snapshot because entries of `_modules` are reassigned.
        for name, child in list(module._modules.items()):
            if child is None:
                continue
            if len(child._modules) > 0:
                self._decompose_inplace(child, percent_removed, seen)
            elif isinstance(child, nn.Linear):
                # Replace the FC layer by its decomposed version
                module._modules[name] = self.SVD(child, percent_removed)

    def SVD(self, layer, percent_removed):
        """Factorize one `nn.Linear` into two smaller ones with a truncated SVD.

        Args:
            layer: the `nn.Linear` to factorize (with or without bias).
            percent_removed: fraction of singular values to drop. The number
                kept is clamped to the range [1, number of singular values].

        Returns:
            `nn.Sequential(layer_1, layer_2)` where `layer_1` (no bias) maps
            `in_features -> L` and `layer_2` maps `L -> out_features`. `layer_2`
            carries a copy of the original bias, or zeros if `layer` has none.
        """
        W = layer.weight.data
        U, S, V = torch.svd(W)

        # Base the count on the number of singular values (min(out, in)), not on
        # out_features, so the declared layer sizes always match the weights.
        n_singular = S.shape[0]
        L = int((1. - percent_removed) * n_singular)
        L = max(1, min(L, n_singular))  # always keep at least one component

        # W ~= W1 @ W2, with W1 = U_L and W2 = diag(S_L) @ V_L^T
        W1 = U[:, :L].contiguous()
        W2 = torch.diag(S[:L]) @ V[:, :L].t()

        layer_1 = nn.Linear(in_features=layer.in_features,
                    out_features=L, bias=False)
        layer_1.weight.data = W2

        layer_2 = nn.Linear(in_features=L,
                    out_features=layer.out_features, bias=True)
        layer_2.weight.data = W1

        # `layer.bias` itself is None when the layer was built with bias=False.
        if layer.bias is None:
            layer_2.bias.data = W.new_zeros(layer.out_features)
        else:
            # Clone so the new layer does not share storage with the source layer.
            layer_2.bias.data = layer.bias.data.clone()

        return nn.Sequential(layer_1, layer_2)

In [None]:
# Instantiate the decomposer used in the cells below.
fc = FC_Decomposer()

In [None]:
#hide
def count_parameters(model):
    "Total number of elements over all the parameters returned by `model.parameters()`."
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [None]:
# Decompose the trained model with the default `percent_removed=0.5`; `learn.model` itself is deep-copied, not modified.
new_model = fc.decompose(learn.model)

In [None]:
# Display the architecture of the decomposed model.
new_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

We can now compare the number of parameters before and after the decomposition:

In [None]:
# Parameter count of the original model.
count_parameters(learn.model)

134277186

In [None]:
# Parameter count of the decomposed model.
count_parameters(new_model)

91281476

This represents a decrease of ~43M parameters!

Now this is an approximation, so it isn't really lossless and we should expect to see a performance drop, which will be bigger as we keep fewer singular values. Here we have:

In [None]:
# Wrap the decomposed model in a new Learner on the same dataloaders so it can be evaluated.
new_learn = Learner(dls, new_model, metrics=accuracy)

In [None]:
# Evaluate the decomposed model without any retraining.
new_learn.validate()

(#2) [0.6868855357170105,0.6853856444358826]