Load the data.

Train a simple MLP with and without pruning.
    1. Which network structure (layer sizes)? We can experiment with it.

Evaluate the accuracy.

Use the experiment class to calculate and compare BOBs.


In [1]:
import numpy as np
import csv
import pandas as pd
from src.train import train, evaluate
import torch
import torch.nn as nn
from src.networks_fashion_mnist import MLP_small_fashion_mnist
from torch.utils.data import random_split, DataLoader, Dataset
from src.utils_quantization import attach_weight_quantizers, toggle_quantization
from src.quantizer import DeadZoneLDZCompander
from torchvision import transforms, datasets

In [2]:

mnist_fashion_transform = transforms.Compose([
    transforms.ToTensor(),
    # (mean,), (std,) normalisation; values presumably the Fashion-MNIST
    # training-set statistics — NOTE(review): confirm.
    transforms.Normalize((0.2860,), (0.3530,)),
])

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=mnist_fashion_transform
)

# Use the SAME transform as the training data: the model is trained on
# normalised inputs, so evaluating on un-normalised ToTensor() output would
# silently shift the input distribution.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=mnist_fashion_transform
)

In [3]:
def structured_sparsity_reg(model, lambda_linear=1e-4, lambda_conv=1e-4):
    """Group-lasso penalty summed over every Linear / Conv2d layer of ``model``.

    Groups are the *columns* of each (flattened) weight matrix:

    * ``nn.Linear`` weight ``[out, in]``: one group per input feature
      (L2 norm over ``dim=0``).
    * ``nn.Conv2d`` weight ``[out_c, in_c, kH, kW]``: flattened to
      ``[out_c, in_c*kH*kW]``, then one group per flattened input position
      ``(in_c, kH, kW)`` (L2 norm over ``dim=0``). This is NOT one group per
      output channel; that would be ``dim=1``.

    Args:
        model: module whose ``nn.Linear`` / ``nn.Conv2d`` submodules
            (including subclasses) are penalised.
        lambda_linear: coefficient applied to the linear-layer group norms.
        lambda_conv: coefficient applied to the conv-layer group norms.

    Returns:
        A scalar tensor, differentiable w.r.t. the weights, when the model
        contains at least one Linear/Conv2d layer; the float ``0.0`` otherwise.
    """
    reg = 0.0
    for m in model.modules():
        if isinstance(m, nn.Linear):
            W = m.weight  # [out, in]
            # Column norms: one group per input feature. Use dim=1 for one
            # group per output neuron instead.
            reg = reg + lambda_linear * torch.norm(W, dim=0).sum()
        elif isinstance(m, nn.Conv2d):
            W = m.weight  # [out_c, in_c, kH, kW]
            # reshape rather than view: view raises on non-contiguous weights
            # (e.g. channels_last memory format).
            W_flat = W.reshape(W.size(0), -1)  # [out_c, in_c*kH*kW]
            # Column norms: one group per (in_c, kH, kW) position. Use dim=1
            # for one group per output channel (filter) instead.
            reg = reg + lambda_conv * torch.norm(W_flat, dim=0).sum()
    return reg


In [4]:
# Reproducibility: seed the model initialisation, the global RNG (used by the
# shuffling DataLoader) and both dataset splits.
SEED = 42
torch.manual_seed(SEED)

model = MLP_small_fashion_mnist()
total_size = len(training_data)
training_percentage = 0.05  # fraction of the training set used in this experiment
training_size = round(total_size*training_percentage)

# `remaining` is the subset that is KEPT; the rest of the training set is discarded.
remaining, _ = random_split(
    training_data,
    [training_size, total_size - training_size],
    generator=torch.Generator().manual_seed(SEED),
)

val_split_percentage = 0.10  # fraction of the kept subset held out for validation
val_size = round(val_split_percentage*training_size)

train_dataset, val_dataset = random_split(
    remaining,
    [len(remaining) - val_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
# Batched validation; the DataLoader default batch_size=1 runs one forward pass per sample.
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

# Learning rates for the three optimizer parameter groups
# (regular weights, dead-zone logits, bit-width logits).
LR_BASE = 0.001
LR_DZ = 0.001
LR_BIT = 0.001

# Weight decay per parameter group (same order as the learning rates above).
WD = 0.0
WD_DZ = 2.5
WD_BIT = 0.0
# NOTE(review): not used by the training cell as written; presumably intended
# as the coefficient for structured_sparsity_reg — confirm.
REG_STRUCTURED = 0.1
QUANT_ARGS = {"fixed_bit_val": 8, "init_deadzone_logit": 3.0, "max_bits": 8, "learnable_deadzone": True, "learnable_bit": False}



In [5]:
criterion = nn.CrossEntropyLoss()
epochs = 1

# Attach a DeadZoneLDZCompander fake-quant parametrization to the model's weights.
# NOTE(review): this mutates `model` in place; re-running this cell without
# re-running the previous one may attach quantizers twice — confirm in
# src.utils_quantization.
attach_weight_quantizers(model=model,
                         exclude_layers=[],
                         quantizer=DeadZoneLDZCompander,
                         quantizer_kwargs=QUANT_ARGS,
                         enabled=True
                         )

# Split parameters by name so the quantizer logits get their own
# learning rate and weight decay.
base_params = []
dz_params = []
bit_params = []

for name, param in model.named_parameters():
    if 'logit_dz' in name:
        dz_params.append(param)
    elif 'logit_bit' in name:
        bit_params.append(param)
    else:
        base_params.append(param)

# Built AFTER the quantizers are attached so their parameters are included.
optimizer = torch.optim.AdamW([
    {'params': base_params, 'lr': LR_BASE, 'weight_decay': WD},
    {'params': dz_params, 'lr': LR_DZ, 'weight_decay': WD_DZ},
    {'params': bit_params, 'lr': LR_BIT, 'weight_decay': WD_BIT},
])

toggle_quantization(model=model, enabled=True)

# NOTE(review): structured_sparsity_reg / REG_STRUCTURED are not passed to
# train() here, so no structured-sparsity penalty is applied in this run.
train(epochs=epochs, model=model, optimizer=optimizer, criterion=criterion, train_loader=train_loader, val_loader=val_loader)


Attached weight quantizer to layer: net.0
Attached weight quantizer to layer: net.2
Attached weight quantizer to layer: net.4
Epoch [1/1] | Train Loss: 1.4696, Train Acc: 55.04% | Val Loss: 0.8230, Val Acc: 72.33%


In [None]:
# Final step of the plan: compare BOBs for the trained, quantized model.
# NOTE(review): what compare_model computes/prints/returns is defined in
# src.bobs_calculator — not visible here.
from src.bobs_calculator import compare_model

compare_model(model)

MLP_small_fashion_mnist(
  (net): Sequential(
    (0): ParametrizedLinear(
      in_features=784, out_features=256, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): FakeQuantParametrization(
            (quantizer): DeadZoneLDZCompander()
          )
        )
      )
    )
    (1): ReLU()
    (2): ParametrizedLinear(
      in_features=256, out_features=128, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): FakeQuantParametrization(
            (quantizer): DeadZoneLDZCompander()
          )
        )
      )
    )
    (3): ReLU()
    (4): ParametrizedLinear(
      in_features=128, out_features=10, bias=True
      (parametrizations): ModuleDict(
        (weight): ParametrizationList(
          (0): FakeQuantParametrization(
            (quantizer): DeadZoneLDZCompander()
          )
        )
      )
    )
  )
)
