In [None]:
# Install the third-party packages that later cells import (tensorly, ptflops).
# NOTE(review): neither install is version-pinned; the install log captured below
# shows tensorly 0.7.0 was resolved for this run -- consider pinning so that a
# re-run picks up the same release.
!pip install tensorly
!pip install ptflops


Collecting tensorly
  Downloading tensorly-0.7.0-py3-none-any.whl (198 kB)
[?25l[K     |█▋                              | 10 kB 15.2 MB/s eta 0:00:01[K     |███▎                            | 20 kB 16.7 MB/s eta 0:00:01[K     |█████                           | 30 kB 17.6 MB/s eta 0:00:01[K     |██████▋                         | 40 kB 12.1 MB/s eta 0:00:01[K     |████████▎                       | 51 kB 6.4 MB/s eta 0:00:01[K     |██████████                      | 61 kB 5.9 MB/s eta 0:00:01[K     |███████████▋                    | 71 kB 5.7 MB/s eta 0:00:01[K     |█████████████▎                  | 81 kB 6.3 MB/s eta 0:00:01[K     |███████████████                 | 92 kB 6.1 MB/s eta 0:00:01[K     |████████████████▌               | 102 kB 5.7 MB/s eta 0:00:01[K     |██████████████████▏             | 112 kB 5.7 MB/s eta 0:00:01[K     |███████████████████▉            | 122 kB 5.7 MB/s eta 0:00:01[K     |█████████████████████▌          | 133 kB 5.7 MB/s eta 0:00:01

In [1]:
# helper function for training(fine-tuning)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms
from ptflops import get_model_complexity_info

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

# Preprocessing pipeline: convert each image to a tensor, then normalise the
# three channels with mean 0.5 and std 0.5.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalize])

# CIFAR-10 train / test splits stored under ./data/cifar10 (download=True).
_cifar_kwargs = dict(download=True, transform=transform)
train_dataset = datasets.CIFAR10("./data/cifar10", train=True, **_cifar_kwargs)
test_dataset = datasets.CIFAR10("./data/cifar10", train=False, **_cifar_kwargs)

# Mini-batch loaders shared by train() and test(); only the training split is shuffled.
_loader_kwargs = dict(batch_size=64, num_workers=2)
TRAIN_DATALOADER = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
TEST_DATALOADER = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

@torch.no_grad()
def test(model):
  """Print the top-1 accuracy of `model` over TEST_DATALOADER.

  The model is put into eval mode and is left in that mode on return.
  Gradients are disabled by the decorator. Nothing is returned; the
  accuracy is only printed.
  """
  model.eval()
  correct = 0
  total = 0
  for batch in TEST_DATALOADER:
    images = batch[0].to(DEVICE)
    targets = batch[1].to(DEVICE)
    # index of the largest score along dim 1 is the predicted class
    predictions = model(images).max(dim=1).indices
    correct += (predictions == targets).sum()
    total += targets.size(0)
  print(f"Test acc: {correct/total*100.0:.2f}%")

def train(model, num_epochs=5):
  """Fine-tune `model` on TRAIN_DATALOADER and evaluate it after every epoch.

  Uses cross-entropy loss and an Adam optimizer with default settings over
  all of the model's parameters. After each epoch the mean per-batch loss
  is printed and test(model) is called.

  Args:
    model: network to optimise in place; the caller is expected to have
      placed it on DEVICE already (inputs are moved there, the model is not).
    num_epochs: number of passes over TRAIN_DATALOADER.
  """
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters())
  for epoch in range(num_epochs):
    # test() leaves the model in eval mode, so re-enable train mode every epoch.
    model.train()
    running_loss = 0.0
    num_batches = 0
    for data in TRAIN_DATALOADER:
      inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)
      optimizer.zero_grad()
      outputs = model(inputs)

      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      running_loss += loss.item()
      num_batches += 1
    # Count batches explicitly instead of reusing the loop index: with a loader
    # that yields nothing the index would be unbound and the summary would raise.
    print('epoch : %d loss: %.3f' %(epoch + 1, running_loss / max(num_batches, 1)))
    test(model)
  print('Finished Training')

Files already downloaded and verified
Files already downloaded and verified


## 1. 모델 및 데이터로더 생성, 학습

In [18]:
# Create model
class MyModel(nn.Module):
    """Classifier built on top of another model's `features` sub-module.

    The borrowed feature extractor is followed by adaptive average pooling
    to a 1x1 spatial size, a flatten, and a new linear layer that maps 512
    values to `label` outputs.
    """

    def __init__(self, model_conv, label=10):
        """Reuse `model_conv.features` and attach a fresh `label`-way linear head."""
        super().__init__()
        self.features = model_conv.features
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=512, out_features=label)

    def forward(self, x):
        """Return the linear head's output for the input batch `x`."""
        pooled = self.avgpool(self.features(x))
        # keep dim 0 (batch) and collapse everything after it
        return self.fc(pooled.flatten(1))

def create_model():
  """Build a MyModel around the `features` of torchvision's vgg11_bn (pretrained=True)."""
  return MyModel(models.vgg11_bn(pretrained=True))

def get_macs(model):
  """Measure `model` with ptflops for a (3, 32, 32) input.

  The model is moved to DEVICE first. ptflops is asked for string-formatted
  results and prints its per-layer statistics as a side effect.

  Returns:
    The (macs, params) pair produced by get_model_complexity_info.
  """
  input_shape = (3, 32, 32)
  # calculate macs
  mac_count, param_count = get_model_complexity_info(
      model.to(DEVICE),
      input_shape,
      as_strings=True,
      print_per_layer_stat=True,
      verbose=True,
  )
  return mac_count, param_count

In [19]:
# model creation test
# Build the model, move it to DEVICE and print its layer structure.
# `test_model` is the notebook-level model that the following cells measure,
# fine-tune and decompose.
test_model = create_model().to(DEVICE)
print(test_model)

MyModel(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2

In [20]:
def count_model_params(
    model: torch.nn.Module,
) -> int:
    """Return how many scalar values the model's trainable parameters hold.

    Parameters whose `requires_grad` is False are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Trainable-parameter count of the model at this point in the notebook
# (this cell runs before the decomposition cell).
params_nums = count_model_params(test_model)
print(params_nums)

9231114


### Original 모델 MACs 확인

In [21]:
# Complexity of the model before decomposition; get_macs also prints the
# per-layer table shown in the output.
macs, params = get_macs(test_model)
print(f'original_model: {macs}')

MyModel(
  9.231 M, 100.000% Params, 0.154 GMac, 100.000% MACs, 
  (features): Sequential(
    9.226 M, 99.944% Params, 0.153 GMac, 99.996% MACs, 
    (0): Conv2d(0.002 M, 0.019% Params, 0.002 GMac, 1.195% MACs, 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(0.0 M, 0.001% Params, 0.0 GMac, 0.085% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.043% MACs, inplace=True)
    (3): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.043% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(0.074 M, 0.800% Params, 0.019 GMac, 12.317% MACs, 64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(0.0 M, 0.003% Params, 0.0 GMac, 0.043% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.021% MACs, inplace=True)
    (7): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.021% MA

### Fine tune 및 Original 모델 Acc 확인

In [None]:
# Fine-tune the not-yet-decomposed model for 10 epochs; train() prints the
# mean loss and the test accuracy after every epoch.
print('Fine tune orignal model')
train(test_model, num_epochs = 10)

Fine tune orignal model
epoch : 1 loss: 0.687
Test acc: 79.76%
epoch : 2 loss: 0.404
Test acc: 84.28%
epoch : 3 loss: 0.263
Test acc: 84.32%
epoch : 4 loss: 0.180
Test acc: 84.19%
epoch : 5 loss: 0.131
Test acc: 84.91%
epoch : 6 loss: 0.102
Test acc: 84.98%
epoch : 7 loss: 0.080
Test acc: 85.66%
epoch : 8 loss: 0.067
Test acc: 85.48%
epoch : 9 loss: 0.064
Test acc: 85.41%
epoch : 10 loss: 0.056
Test acc: 84.45%
Finished Training


## 2. Conv에 rank parameter 추가(register_buffer)

In [22]:
# Attach a two-entry `rank` buffer to every Conv2d inside the model.
# tucker_decomposition_conv_layer picks it up via hasattr/getattr and multiplies
# entry i by weight.shape[i] to obtain the rank for that dimension.
conv_layers = [m for m in test_model.modules() if isinstance(m, nn.Conv2d)]
for conv in conv_layers:
    conv.register_buffer('rank', torch.Tensor([0.5, 0.5]))  # ratios for weight dims 0 and 1

## 3. Decomposition 수행

### 3-1. Tucker decomposition 함수

In [24]:
import copy
import tensorly as tl
from tensorly.decomposition import parafac, partial_tucker
from typing import List
# switch to the PyTorch backend
tl.set_backend('pytorch')

def tucker_decomposition_conv_layer(
    layer: nn.Module,
    normed_rank: List[float] = [0.5, 0.5],
) -> nn.Module:
    """Replace a conv layer by three stacked convs built from a partial Tucker decomposition.

    The layer's weight is decomposed along dims 0 and 1 with tensorly's
    `partial_tucker`, and the result is laid out as:

      1. a 1x1 conv (no bias) from weight.shape[1] channels to rank[1] channels,
      2. a conv with the original kernel size / stride / padding / dilation
         (no bias) from rank[1] to rank[0] channels,
      3. a 1x1 conv from rank[0] to weight.shape[0] channels, which carries a
         copy of the original bias when the source layer has one.

    Ranks are NOT estimated from the data: each one is
    int(ratio * weight.shape[i]), raised to at least 2 and capped at the size
    of that dimension.

    Args:
        layer: convolution to decompose. Only `weight`, `bias`, `kernel_size`,
            `stride`, `padding` and `dilation` are read.
            NOTE(review): `groups` is not consulted, so this presumably
            expects groups == 1 -- confirm before using on grouped convs.
        normed_rank: two ratios; entry i is applied to weight.shape[i]. Ignored
            when the layer has a `rank` attribute, which then takes precedence.
            The default list is never mutated.

    Returns:
        nn.Sequential containing the three convolutions in application order.
    """
    if hasattr(layer, "rank"):
        normed_rank = getattr(layer, "rank")

    weight = layer.weight.data
    rank = []
    for i, ratio in enumerate(normed_rank):
        dim = weight.shape[i]
        # Floor of 2 as before, but never request more components than the
        # dimension has (e.g. a single-channel input).
        rank.append(min(max(int(ratio * dim), 2), dim))

    result = partial_tucker(
        weight,
        modes=[0, 1],
        n_iter_max=2000000,
        rank=rank,
        init="svd",
    )
    # tensorly 0.7 returns (core, factors). Newer releases are reported to
    # return ((core, factors), rec_errors) -- TODO confirm against the installed
    # version. Unwrap the nested form so both layouts are accepted.
    if isinstance(result[0], (tuple, list)):
        result = result[0]
    # Factors come back in `modes` order: dim-0 factor first, dim-1 factor second.
    core, (last, first) = result

    has_bias = getattr(layer, "bias", None) is not None

    # Pointwise conv that maps the input channels onto the dim-1 rank.
    first_layer = nn.Conv2d(
        in_channels=first.shape[0],
        out_channels=first.shape[1],
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=layer.dilation,
        bias=False,
    )

    # Spatial conv on the reduced channels; keeps the original geometry.
    core_layer = nn.Conv2d(
        in_channels=core.shape[1],
        out_channels=core.shape[0],
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        dilation=layer.dilation,
        bias=False,
    )

    # Pointwise conv that maps the dim-0 rank back to the output channels.
    # It only gets a bias when the source layer had one; otherwise a randomly
    # initialised bias would be added to the decomposed layer.
    last_layer = nn.Conv2d(
        in_channels=last.shape[1],
        out_channels=last.shape[0],
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=layer.dilation,
        bias=has_bias,
    )

    if has_bias:
        # clone so the new layer does not share storage with the source layer
        last_layer.bias.data = layer.bias.data.clone()

    # Factor matrices become 1x1 kernels: append two singleton spatial dims.
    first_layer.weight.data = (
        torch.transpose(first, 1, 0).unsqueeze(-1).unsqueeze(-1).contiguous()
    )
    last_layer.weight.data = last.unsqueeze(-1).unsqueeze(-1)
    core_layer.weight.data = core

    return nn.Sequential(first_layer, core_layer, last_layer)


### 3-2. Decomposition 수행

In [25]:
def decompose(module: nn.Module):
  """Rebuild `module`'s direct children with every Conv2d Tucker-decomposed.

  Children that are exactly nn.Sequential are processed recursively, children
  that are exactly nn.Conv2d are replaced by
  tucker_decomposition_conv_layer(child), and every other child is reused
  unchanged.

  Returns:
    A new nn.Sequential holding the processed children in order, or None
    when `module` has no children.
  """
  rebuilt = []
  for child in module.children():
    # Exact type checks on purpose: subclasses are passed through untouched.
    if type(child) is nn.Sequential:
      nested = decompose(child)
      if nested:
        child = nested
    if type(child) is nn.Conv2d:
      child = tucker_decomposition_conv_layer(child)
    rebuilt.append(child)
  if not rebuilt:
    return None
  return nn.Sequential(*rebuilt)

In [26]:
# Swap the feature extractor for its decomposed version, then move the model to
# DEVICE again because the replacement conv layers are newly constructed modules.
# NOTE(review): this mutates test_model in place; decompose() replaces every
# Conv2d it finds, so re-running this cell would decompose the already
# decomposed convolutions a second time.
test_model.features = decompose(test_model.features)
test_model = test_model.to(DEVICE)

In [27]:
# Trainable-parameter count after decomposition; overwrites the earlier
# `params_nums` value.
params_nums = count_model_params(test_model)
print(params_nums)

3387024


### Decomposed 모델 MACs 확인

In [None]:
# Complexity of the decomposed model; overwrites the earlier `macs` / `params`.
macs, params = get_macs(test_model)
print(f'decomposed_model: {macs}')

MyModel(
  3.387 M, 100.000% Params, 0.06 GMac, 100.000% MACs, 
  (features): Sequential(
    3.382 M, 99.849% Params, 0.06 GMac, 99.991% MACs, 
    (0): Sequential(
      0.003 M, 0.080% Params, 0.003 GMac, 4.634% MACs, 
      (0): Conv2d(0.0 M, 0.000% Params, 0.0 GMac, 0.010% MACs, 3, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): Conv2d(0.001 M, 0.017% Params, 0.001 GMac, 0.991% MACs, 2, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (2): Conv2d(0.002 M, 0.062% Params, 0.002 GMac, 3.633% MACs, 32, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): BatchNorm2d(0.0 M, 0.004% Params, 0.0 GMac, 0.220% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(0.0 M, 0.000% Params, 0.0 GMac, 0.110% MACs, inplace=True)
    (3): MaxPool2d(0.0 M, 0.000% Params, 0.0 GMac, 0.110% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Sequential(
      0.029 M, 0.850% Params, 0.007 GMac, 12.385% MA

### Fine tune 및 decomposed 모델 Acc 확인

In [None]:
# Fine-tune the decomposed model for 10 epochs; train() prints the mean loss
# and the test accuracy after every epoch.
print('Fine tune decomposed model')
train(test_model, num_epochs = 10)

Fine tune decomposed model
epoch : 1 loss: 0.360
Test acc: 81.87%
epoch : 2 loss: 0.270
Test acc: 81.75%
epoch : 3 loss: 0.225
Test acc: 82.76%
epoch : 4 loss: 0.194
Test acc: 81.47%
epoch : 5 loss: 0.168
Test acc: 83.38%
epoch : 6 loss: 0.153
Test acc: 82.21%
epoch : 7 loss: 0.141
Test acc: 82.24%
epoch : 8 loss: 0.122
Test acc: 84.49%
epoch : 9 loss: 0.118
Test acc: 83.12%
epoch : 10 loss: 0.108
Test acc: 83.40%
Finished Training
