In [25]:
import copy

In [50]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm

In [60]:
from neumeta.models import create_model_cifar10, create_densenet_model, create_mnist_model, fuse_module
from neumeta.utils import AverageMeter

from smooth.permute import PermutationManager, compute_tv_loss_for_network

In [35]:
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for each requested value of k.

    Args:
        output: Score tensor indexed as (sample, class); the top-k classes are
            taken along dimension 1.
        target: Tensor of ground-truth class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``, each holding the
        percentage (0-100) of samples whose target is among the k
        highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best classes per sample, best first: (batch, maxk).
    _, pred = output.topk(maxk, 1, True, True)
    # Transpose to (maxk, batch) so that row i holds every sample's rank-i guess.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the non-contiguous layout of the transposed `pred`,
        # so flattening a multi-row slice with .view(-1) raises a RuntimeError;
        # .reshape(-1) copies when necessary and works for every k.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [36]:
def validate(model, val_loader, criterion, args=None, device='cpu'):
    """Evaluate ``model`` on ``val_loader`` and return the average loss and top-1 accuracy.

    Args:
        model: Network to evaluate; it is switched to eval mode and left in it.
        val_loader: Iterable yielding ``(input, target)`` batches.
        criterion: Callable computing the loss from ``(prediction, target)``.
        args: Unused; kept so existing call sites remain valid.
        device: Device every batch is moved to before the forward pass. The
            model itself is not moved here.

    Returns:
        Tuple ``(losses.avg, top1.avg)``: the loss and the top-1 accuracy
        (in percent) as accumulated by ``AverageMeter``, with each batch's
        size passed as the update count.
    """
    # Set the model to evaluation mode
    model.eval()

    # Initialize AverageMeter objects to track the loss and the top-1 accuracy
    losses = AverageMeter()
    top1 = AverageMeter()

    with torch.no_grad():
        # Iterate over the validation data
        for x, target in tqdm(val_loader):
            # Move the batch to the evaluation device
            x, target = x.to(device), target.to(device)
            predict = model(x)

            loss = criterion(predict, target)

            # Measure accuracy and record loss. No `.data` is needed here:
            # nothing is tracked by autograd inside torch.no_grad().
            prec1 = accuracy(predict, target, topk=(1,))[0].item()
            losses.update(loss.item(), x.size(0))
            top1.update(prec1, x.size(0))

    return losses.avg, top1.avg

In [37]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [42]:
# Per-channel normalisation constants, written on the 0-255 scale and divided
# by 255 to match the [0, 1] range that ToTensor() produces.
normalize = transforms.Normalize(
    mean=[125.3 / 255.0, 123.0 / 255.0, 113.9 / 255.0],
    std=[63.0 / 255.0, 62.1 / 255.0, 66.7 / 255.0],
)

# Evaluation pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose([transforms.ToTensor(), normalize])

# Training pipeline: random 32x32 crop after 4-pixel padding and a random
# horizontal flip, followed by the same conversion and normalisation.
transforms_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

In [43]:
# CIFAR-10 training split (downloaded into ./data when missing), using the
# augmenting training transforms.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=transforms_train,
    download=True,
)
# CIFAR-10 test split, used here as the validation set; no download flag is passed.
val_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform_test,
)

In [45]:
# Mini-batches of 64 samples; only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
)

In [46]:
# Cross-entropy loss, passed as the criterion to every validate() call in this notebook.
criterion = nn.CrossEntropyLoss()

## ResNet

In [3]:
# Build the CIFAR-10 ResNet20. Per the message this call prints, 64 is the
# hidden dimension of the replaced layer3 blocks and pretrained weights are loaded.
resnet20 = create_model_cifar10('ResNet20', 64)
# Bare expression so the notebook displays the module structure.
resnet20

Replace the last 2 block of layer3 with new block with hidden dim 64
Loading pretrained weights for resnet20


CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): Identity()
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (2): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer2): Sequential(


In [14]:
# List every registered parameter of the ResNet together with its shape.
for name, param in resnet20.named_parameters():
    print(name, param.shape)

conv1.weight torch.Size([16, 3, 3, 3])
conv1.bias torch.Size([16])
layer1.0.conv1.weight torch.Size([16, 16, 3, 3])
layer1.0.conv1.bias torch.Size([16])
layer1.0.conv2.weight torch.Size([16, 16, 3, 3])
layer1.0.conv2.bias torch.Size([16])
layer1.1.conv1.weight torch.Size([16, 16, 3, 3])
layer1.1.conv1.bias torch.Size([16])
layer1.1.conv2.weight torch.Size([16, 16, 3, 3])
layer1.1.conv2.bias torch.Size([16])
layer1.2.conv1.weight torch.Size([16, 16, 3, 3])
layer1.2.conv1.bias torch.Size([16])
layer1.2.conv2.weight torch.Size([16, 16, 3, 3])
layer1.2.conv2.bias torch.Size([16])
layer2.0.conv1.weight torch.Size([32, 16, 3, 3])
layer2.0.conv1.bias torch.Size([32])
layer2.0.conv2.weight torch.Size([32, 32, 3, 3])
layer2.0.conv2.bias torch.Size([32])
layer2.0.downsample.0.weight torch.Size([32, 16, 1, 1])
layer2.0.downsample.0.bias torch.Size([32])
layer2.1.conv1.weight torch.Size([32, 32, 3, 3])
layer2.1.conv1.bias torch.Size([32])
layer2.1.conv2.weight torch.Size([32, 32, 3, 3])
layer2.1.c

In [18]:
# Print the key and tensor shape of each entry exposed by the model's
# `learnable_parameter` attribute (iterated by key, then indexed by that key).
for k in resnet20.learnable_parameter:
    print(k, resnet20.learnable_parameter[k].shape)


layer3.2.conv1.weight torch.Size([64, 64, 3, 3])
layer3.2.conv1.bias torch.Size([64])
layer3.2.conv2.weight torch.Size([64, 64, 3, 3])
layer3.2.conv2.bias torch.Size([64])


In [81]:
# Total number of scalar values across all parameters of the ResNet.
sum(param.numel() for param in resnet20.parameters())

271690

## LeNet

In [10]:
# Build the MNIST LeNet variant. NOTE(review): 32 presumably sets the width of
# the first convolution (conv_1 shows 32 output channels) — confirm against
# create_mnist_model.
LeNet = create_mnist_model('LeNet', 32)
# Bare expression so the notebook displays the module structure.
LeNet

MnistNet(
  (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
  (conv_2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
  (conv_3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
  (f_1): ReLU()
  (f_2): ReLU()
  (f_3): ReLU()
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (linear): Linear(in_features=128, out_features=10, bias=True)
)

In [12]:
# List every registered parameter of the LeNet together with its shape.
for name, param in LeNet.named_parameters():
    print(name, param.shape)

conv_1.weight torch.Size([32, 1, 3, 3])
conv_1.bias torch.Size([32])
conv_2.weight torch.Size([64, 32, 5, 5])
conv_2.bias torch.Size([64])
conv_3.weight torch.Size([128, 64, 5, 5])
conv_3.bias torch.Size([128])
linear.weight torch.Size([10, 128])
linear.bias torch.Size([10])


In [21]:
# Print the key and tensor shape of each entry exposed by the model's
# `learnable_parameter` attribute (iterated by key, then indexed by that key).
for k in LeNet.learnable_parameter:
    print(k, LeNet.learnable_parameter[k].shape)

conv_1.weight torch.Size([32, 1, 3, 3])
conv_1.bias torch.Size([32])
conv_2.weight torch.Size([64, 32, 5, 5])
conv_2.bias torch.Size([64])
conv_3.weight torch.Size([128, 64, 5, 5])
conv_3.bias torch.Size([128])
linear.weight torch.Size([10, 128])
linear.bias torch.Size([10])


In [24]:
# Print the raw values of every entry of `learnable_parameter`.
# NOTE(review): this dumps every weight tensor in full, which floods the
# notebook output; consider printing a summary (e.g. shape or norm) instead.
for k in LeNet.learnable_parameter:
    print(k, LeNet.learnable_parameter[k].data)

conv_1.weight tensor([[[[-0.0735, -0.0388, -0.1106],
          [ 0.3012,  0.1019, -0.3244],
          [ 0.1546, -0.0032, -0.1900]]],


        [[[ 0.2336, -0.0763, -0.0320],
          [-0.3323, -0.1091,  0.2338],
          [-0.3218,  0.2272, -0.2021]]],


        [[[ 0.0319, -0.1555,  0.3279],
          [ 0.2069, -0.1175, -0.1551],
          [ 0.2684, -0.2731,  0.0276]]],


        [[[ 0.2689,  0.2858, -0.0018],
          [ 0.1553,  0.0697,  0.1824],
          [-0.3107, -0.1283, -0.2914]]],


        [[[ 0.2154,  0.2647, -0.1805],
          [ 0.0115, -0.1471, -0.0709],
          [-0.1938, -0.0823, -0.2273]]],


        [[[ 0.0874,  0.2806, -0.2308],
          [-0.0901,  0.2634,  0.0506],
          [ 0.1742, -0.1050, -0.2218]]],


        [[[-0.0009, -0.1704, -0.2909],
          [ 0.1640,  0.1381,  0.0827],
          [-0.2975, -0.1517, -0.1731]]],


        [[[ 0.0685,  0.3230, -0.1109],
          [-0.2757, -0.3110, -0.0805],
          [ 0.0835,  0.2018, -0.1628]]],


        [[[-0.2939

In [80]:
# Total number of scalar values across all parameters of the LeNet.
sum(param.numel() for param in LeNet.parameters())

257802

## DenseNet

In [39]:
# Build the DenseNet and move it to the selected device.
# NOTE(review): going by the variable name and the printed structure, the
# positional arguments look like depth=40, growth rate=12, reduction=0.5,
# bottleneck=True and dropout=0.0 — confirm against create_densenet_model.
densenet_bc_40_12 = create_densenet_model('DenseNet', 40, 12, 0.5, True, 0.0).to(device)
# Bare expression so the notebook displays the module structure.
densenet_bc_40_12

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kern

In [13]:
# List every registered parameter of the DenseNet together with its shape.
for name, param in densenet_bc_40_12.named_parameters():
    print(name, param.shape)

conv1.weight torch.Size([24, 3, 3, 3])
block1.layer.0.bn1.weight torch.Size([24])
block1.layer.0.bn1.bias torch.Size([24])
block1.layer.0.conv1.weight torch.Size([48, 24, 1, 1])
block1.layer.0.bn2.weight torch.Size([48])
block1.layer.0.bn2.bias torch.Size([48])
block1.layer.0.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.1.bn1.weight torch.Size([36])
block1.layer.1.bn1.bias torch.Size([36])
block1.layer.1.conv1.weight torch.Size([48, 36, 1, 1])
block1.layer.1.bn2.weight torch.Size([48])
block1.layer.1.bn2.bias torch.Size([48])
block1.layer.1.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.2.bn1.weight torch.Size([48])
block1.layer.2.bn1.bias torch.Size([48])
block1.layer.2.conv1.weight torch.Size([48, 48, 1, 1])
block1.layer.2.bn2.weight torch.Size([48])
block1.layer.2.bn2.bias torch.Size([48])
block1.layer.2.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.3.bn1.weight torch.Size([60])
block1.layer.3.bn1.bias torch.Size([60])
block1.layer.3.conv1.weight torch.Size([48,

In [None]:
# Keep the checkpoint path and the loaded checkpoint under separate names, so
# that `densenet_checkpoint` always refers to the loaded object and never to a string.
densenet_checkpoint_path = 'toy/experiments/densenet_bc_40_12_baseline/densenet_bc_40_12_cifar10_baseline_best.pth'
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints that come from a trusted source.
densenet_checkpoint = torch.load(densenet_checkpoint_path, map_location=device)

In [48]:
# Load the weights stored under the checkpoint's 'model_state_dict' key into the DenseNet.
densenet_bc_40_12.load_state_dict(densenet_checkpoint['model_state_dict'])

<All keys matched successfully>

In [51]:
# Baseline: evaluate the DenseNet with the loaded checkpoint weights on the validation loader.
val_loss, acc = validate(densenet_bc_40_12, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:04<00:00, 38.51it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.36%





In [53]:
# Work on an independent deep copy so the loaded baseline model stays untouched.
densenet_bc_40_12_test = copy.deepcopy(densenet_bc_40_12)

In [54]:
# Sanity check: evaluate the deep copy, which should reproduce the baseline metrics.
val_loss, acc = validate(densenet_bc_40_12_test, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 40.87it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.36%





In [77]:
# Total number of scalar values across all parameters of the baseline DenseNet.
sum(param.numel() for param in densenet_bc_40_12.parameters())

176122

In [55]:
# Fuse the copy in place (the return value is not used). In the structure
# printed afterwards, each bn2 has become Identity and conv1 has gained a bias.
fuse_module(densenet_bc_40_12_test)

In [56]:
# Re-evaluate the fused copy to check that fusion preserved the metrics.
val_loss, acc = validate(densenet_bc_40_12_test, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 41.18it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.37%





In [76]:
# Display the structure of the fused model.
densenet_bc_40_12_test

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (2): BottleneckBlock(
        (bn1): Bat

In [58]:
# List every registered parameter of the fused DenseNet together with its shape.
for name, param in densenet_bc_40_12_test.named_parameters():
    print(name, param.shape)

conv1.weight torch.Size([24, 3, 3, 3])
block1.layer.0.bn1.weight torch.Size([24])
block1.layer.0.bn1.bias torch.Size([24])
block1.layer.0.conv1.weight torch.Size([48, 24, 1, 1])
block1.layer.0.conv1.bias torch.Size([48])
block1.layer.0.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.1.bn1.weight torch.Size([36])
block1.layer.1.bn1.bias torch.Size([36])
block1.layer.1.conv1.weight torch.Size([48, 36, 1, 1])
block1.layer.1.conv1.bias torch.Size([48])
block1.layer.1.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.2.bn1.weight torch.Size([48])
block1.layer.2.bn1.bias torch.Size([48])
block1.layer.2.conv1.weight torch.Size([48, 48, 1, 1])
block1.layer.2.conv1.bias torch.Size([48])
block1.layer.2.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.3.bn1.weight torch.Size([60])
block1.layer.3.bn1.bias torch.Size([60])
block1.layer.3.conv1.weight torch.Size([48, 60, 1, 1])
block1.layer.3.conv1.bias torch.Size([48])
block1.layer.3.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.

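> The results above show that DenseNet still works after `fuse_module` has been applied (loss and accuracy are essentially unchanged), and that the fusion itself took effect: each `bn2` is replaced by `Identity` and the preceding `conv1` gains a bias.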

In [78]:
# Total number of scalar values across all parameters of the fused DenseNet.
sum(param.numel() for param in densenet_bc_40_12_test.parameters())

175258

In [66]:
# Separate deep copy of the fused model, used as the input to the permutation step.
densenet_bc_40_12_test_smooth = copy.deepcopy(densenet_bc_40_12_test)
# Make sure the copy is on the selected device; as the last expression this
# also displays the model structure.
densenet_bc_40_12_test_smooth.to(device)

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (2): BottleneckBlock(
        (bn1): Bat

In [67]:
# Evaluate the copy before any permutation is applied, as the reference point.
val_loss, acc = validate(densenet_bc_40_12_test_smooth, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 41.61it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.37%





In [68]:
# TV loss of the fused model before permutation, as reported by
# compute_tv_loss_for_network with lambda_tv=1.0.
print(f'Old TV original model: {compute_tv_loss_for_network(densenet_bc_40_12_test_smooth, lambda_tv=1.0).item()}')

Old TV original model: 428.705078125


In [71]:
# Random input holding a single 3x32x32 sample, on the same device as the model.
input_tensor = torch.randn(1, 3, 32, 32).to(device)
# NOTE(review): the example input is presumably used by PermutationManager to
# trace the network — confirm against smooth.permute.
permute_func = PermutationManager(densenet_bc_40_12_test_smooth, input_tensor)

In [72]:
# Compute the permutation dictionary that is applied to the model in the next step.
permute_dict = permute_func.compute_permute_dict()

In [73]:
# Apply the permutations, skipping the listed (parameter, dimension) pairs.
# NOTE(review): these are the first conv's input channels and the fc layer's
# outputs, presumably excluded so the network's input and class ordering stay
# fixed — confirm against PermutationManager.apply_permutations.
densenet_bc_40_12_smooth = permute_func.apply_permutations(permute_dict, ignored_keys=[
    ('conv1.weight', 'in_channels'),
    ('fc.weight', 'out_channels'),
    ('fc.bias', 'out_channels')
])

In [75]:
# TV loss of the model returned by apply_permutations. The label is changed
# from the copy-pasted "Old TV original model", which described the
# pre-permutation model, so the two printed values can be told apart.
print(f'New TV smoothed model: {compute_tv_loss_for_network(densenet_bc_40_12_smooth, lambda_tv=1.0).item()}')

Old TV original model: 394.69140625


In [74]:
# Evaluate the permuted (smoothed) model to check that the metrics are preserved.
val_loss, acc = validate(densenet_bc_40_12_smooth, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 40.89it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.37%





> This shows that the model still works well after it has been smoothed: the permutation lowers the TV loss while the validation loss and accuracy stay unchanged.

In [79]:
# Total number of scalar values across all parameters of the smoothed DenseNet.
sum(param.numel() for param in densenet_bc_40_12_smooth.parameters())

175258