## Prerequisite

In [1]:
import copy

In [2]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from tqdm import tqdm

In [3]:
from neumeta.models import create_model_cifar10, create_densenet_model, create_mnist_model, fuse_module
from neumeta.models.utils import fuse_module_densenet
from neumeta.utils import AverageMeter

from smooth.permute import PermutationManager, compute_tv_loss_for_network

In [4]:
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [5]:
def validate(model, val_loader, criterion, args=None, device='cpu'):
    # Set the model to training mode
    model.eval()

    # Initialize AverageMeter objects to track the losses
    losses = AverageMeter()
    top1 = AverageMeter()
    
    with torch.no_grad():
        # Iterate over the training data
        for x, target in tqdm(val_loader):
            # Preprocess input
            x, target = x.to(device), target.to(device)
            predict = model(x)
            
            loss = criterion(predict, target) 
            
            # Measure accuracy and record loss
            prec1 = accuracy(predict.data, target, topk=(1,))[0].item()
            losses.update(loss.item(), x.size(0))
            top1.update(prec1, x.size(0))

    return losses.avg, top1.avg

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
normalize = transforms.Normalize(
    mean=[x/255.0 for x in [125.3, 123.0, 113.9]], 
    std=[x/255.0 for x in [63.0, 62.1, 66.7]]
    )

transforms_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

In [8]:
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transforms_train, download=True)
val_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform_test)

In [9]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [10]:
criterion = nn.CrossEntropyLoss()

## ResNet

In [11]:
resnet20 = create_model_cifar10('ResNet20', 64)
resnet20

Replace the last 2 block of layer3 with new block with hidden dim 64
Loading pretrained weights for resnet20


CifarResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): Identity()
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
    (2): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): Identity()
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): Identity()
    )
  )
  (layer2): Sequential(


In [12]:
for k, w in resnet20.named_parameters():
    print(k, w.shape)

conv1.weight torch.Size([16, 3, 3, 3])
conv1.bias torch.Size([16])
layer1.0.conv1.weight torch.Size([16, 16, 3, 3])
layer1.0.conv1.bias torch.Size([16])
layer1.0.conv2.weight torch.Size([16, 16, 3, 3])
layer1.0.conv2.bias torch.Size([16])
layer1.1.conv1.weight torch.Size([16, 16, 3, 3])
layer1.1.conv1.bias torch.Size([16])
layer1.1.conv2.weight torch.Size([16, 16, 3, 3])
layer1.1.conv2.bias torch.Size([16])
layer1.2.conv1.weight torch.Size([16, 16, 3, 3])
layer1.2.conv1.bias torch.Size([16])
layer1.2.conv2.weight torch.Size([16, 16, 3, 3])
layer1.2.conv2.bias torch.Size([16])
layer2.0.conv1.weight torch.Size([32, 16, 3, 3])
layer2.0.conv1.bias torch.Size([32])
layer2.0.conv2.weight torch.Size([32, 32, 3, 3])
layer2.0.conv2.bias torch.Size([32])
layer2.0.downsample.0.weight torch.Size([32, 16, 1, 1])
layer2.0.downsample.0.bias torch.Size([32])
layer2.1.conv1.weight torch.Size([32, 32, 3, 3])
layer2.1.conv1.bias torch.Size([32])
layer2.1.conv2.weight torch.Size([32, 32, 3, 3])
layer2.1.c

In [13]:
for k in resnet20.learnable_parameter:
    print(k, resnet20.learnable_parameter[k].shape)


layer3.2.conv1.weight torch.Size([64, 64, 3, 3])
layer3.2.conv1.bias torch.Size([64])
layer3.2.conv2.weight torch.Size([64, 64, 3, 3])
layer3.2.conv2.bias torch.Size([64])


In [14]:
sum(p.numel() for p in resnet20.parameters())

271690

## LeNet

In [15]:
LeNet = create_mnist_model('LeNet', 32)
LeNet

MnistNet(
  (conv_1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=replicate)
  (conv_2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
  (conv_3): Conv2d(64, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=replicate)
  (f_1): ReLU()
  (f_2): ReLU()
  (f_3): ReLU()
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (linear): Linear(in_features=128, out_features=10, bias=True)
)

In [16]:
for k, w in LeNet.named_parameters():
    print(k, w.shape)

conv_1.weight torch.Size([32, 1, 3, 3])
conv_1.bias torch.Size([32])
conv_2.weight torch.Size([64, 32, 5, 5])
conv_2.bias torch.Size([64])
conv_3.weight torch.Size([128, 64, 5, 5])
conv_3.bias torch.Size([128])
linear.weight torch.Size([10, 128])
linear.bias torch.Size([10])


In [17]:
for k in LeNet.learnable_parameter:
    print(k, LeNet.learnable_parameter[k].shape)

conv_1.weight torch.Size([32, 1, 3, 3])
conv_1.bias torch.Size([32])
conv_2.weight torch.Size([64, 32, 5, 5])
conv_2.bias torch.Size([64])
conv_3.weight torch.Size([128, 64, 5, 5])
conv_3.bias torch.Size([128])
linear.weight torch.Size([10, 128])
linear.bias torch.Size([10])


In [18]:
for k in LeNet.learnable_parameter:
    print(k, LeNet.learnable_parameter[k].data)

conv_1.weight tensor([[[[ 0.2054,  0.3288,  0.2130],
          [-0.0336, -0.0210, -0.0629],
          [-0.2798,  0.2457,  0.0172]]],


        [[[-0.0524,  0.2950, -0.2679],
          [ 0.1621, -0.2832,  0.1234],
          [-0.0057,  0.2922,  0.2802]]],


        [[[-0.1033, -0.0040,  0.1100],
          [-0.0687,  0.1351, -0.1614],
          [-0.0263,  0.0230, -0.2629]]],


        [[[ 0.1576,  0.0304,  0.1967],
          [ 0.0450,  0.0725,  0.1436],
          [ 0.0100, -0.2416, -0.2341]]],


        [[[ 0.2208, -0.1208, -0.1912],
          [-0.1389,  0.2640, -0.1980],
          [ 0.1340, -0.0815, -0.0885]]],


        [[[-0.0146,  0.0241,  0.1426],
          [-0.0791,  0.1165, -0.0427],
          [-0.0554, -0.2918, -0.2729]]],


        [[[ 0.1892,  0.0591, -0.1823],
          [ 0.0567, -0.1519, -0.2491],
          [ 0.3289,  0.3224, -0.1803]]],


        [[[-0.0741,  0.2843,  0.1018],
          [-0.0447,  0.1989, -0.1547],
          [ 0.2663, -0.0488, -0.0726]]],


        [[[ 0.1249

In [19]:
sum(p.numel() for p in LeNet.parameters())

257802

## DenseNet

In [11]:
densenet_bc_40_12 = create_densenet_model('DenseNet', 40, 12, 0.5, True, 0.0).to(device)
densenet_bc_40_12

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kern

In [12]:
for k, w in densenet_bc_40_12.named_parameters():
    print(k, w.shape)

conv1.weight torch.Size([24, 3, 3, 3])
block1.layer.0.bn1.weight torch.Size([24])
block1.layer.0.bn1.bias torch.Size([24])
block1.layer.0.conv1.weight torch.Size([48, 24, 1, 1])
block1.layer.0.bn2.weight torch.Size([48])
block1.layer.0.bn2.bias torch.Size([48])
block1.layer.0.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.1.bn1.weight torch.Size([36])
block1.layer.1.bn1.bias torch.Size([36])
block1.layer.1.conv1.weight torch.Size([48, 36, 1, 1])
block1.layer.1.bn2.weight torch.Size([48])
block1.layer.1.bn2.bias torch.Size([48])
block1.layer.1.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.2.bn1.weight torch.Size([48])
block1.layer.2.bn1.bias torch.Size([48])
block1.layer.2.conv1.weight torch.Size([48, 48, 1, 1])
block1.layer.2.bn2.weight torch.Size([48])
block1.layer.2.bn2.bias torch.Size([48])
block1.layer.2.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.3.bn1.weight torch.Size([60])
block1.layer.3.bn1.bias torch.Size([60])
block1.layer.3.conv1.weight torch.Size([48,

In [13]:
densenet_checkpoint = 'toy/experiments/densenet_bc_40_12_baseline/densenet_bc_40_12_cifar10_baseline_best.pth'
densenet_checkpoint = torch.load(densenet_checkpoint, map_location=device)

In [14]:
densenet_bc_40_12.load_state_dict(densenet_checkpoint['model_state_dict'])

<All keys matched successfully>

In [15]:
val_loss, acc = validate(densenet_bc_40_12, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 40.61it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.36%





In [16]:
sum(p.numel() for p in densenet_bc_40_12.parameters())

176122

In [17]:
densenet_bc_40_12_test = copy.deepcopy(densenet_bc_40_12)

In [18]:
val_loss, acc = validate(densenet_bc_40_12_test, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 44.18it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.36%





In [19]:
fuse_module(densenet_bc_40_12_test)

In [20]:
val_loss, acc = validate(densenet_bc_40_12_test, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 44.43it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.37%





In [21]:
densenet_bc_40_12_test

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (2): BottleneckBlock(
        (bn1): Bat

In [22]:
for k, w in densenet_bc_40_12_test.named_parameters():
    print (k, w.shape)

conv1.weight torch.Size([24, 3, 3, 3])
block1.layer.0.bn1.weight torch.Size([24])
block1.layer.0.bn1.bias torch.Size([24])
block1.layer.0.conv1.weight torch.Size([48, 24, 1, 1])
block1.layer.0.conv1.bias torch.Size([48])
block1.layer.0.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.1.bn1.weight torch.Size([36])
block1.layer.1.bn1.bias torch.Size([36])
block1.layer.1.conv1.weight torch.Size([48, 36, 1, 1])
block1.layer.1.conv1.bias torch.Size([48])
block1.layer.1.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.2.bn1.weight torch.Size([48])
block1.layer.2.bn1.bias torch.Size([48])
block1.layer.2.conv1.weight torch.Size([48, 48, 1, 1])
block1.layer.2.conv1.bias torch.Size([48])
block1.layer.2.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.3.bn1.weight torch.Size([60])
block1.layer.3.bn1.bias torch.Size([60])
block1.layer.3.conv1.weight torch.Size([48, 60, 1, 1])
block1.layer.3.conv1.bias torch.Size([48])
block1.layer.3.conv2.weight torch.Size([12, 48, 3, 3])
block1.layer.

> Above shows that DenseNet still works even after it has been fuse_model, and fuse_model also worked

In [23]:
sum(p.numel() for p in densenet_bc_40_12_test.parameters())

175258

In [24]:
densenet_bc_40_12_test_new_fuse_module = copy.deepcopy(densenet_bc_40_12)

val_loss, acc = validate(densenet_bc_40_12_test_new_fuse_module, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 44.17it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.36%





In [25]:
fuse_module_densenet(densenet_bc_40_12_test_new_fuse_module)

Name: conv1 and prev_name: None
Name: block1.layer.0.bn1 and prev_name: conv1
Name: 24 and prev_name: 24
in here
Name: block1.layer.0.relu and prev_name: conv1
Name: block1.layer.0.conv1 and prev_name: conv1
Name: block1.layer.0.bn2 and prev_name: block1.layer.0.conv1
Name: 48 and prev_name: 48
in here
Name: block1.layer.0.conv2 and prev_name: block1.layer.0.conv1
Name: block1.layer.0.dropout and prev_name: block1.layer.0.conv2
Name: block1.layer.1.bn1 and prev_name: block1.layer.0.conv2
Name: 36 and prev_name: 12
Name: block1.layer.1.relu and prev_name: block1.layer.0.conv2
Name: block1.layer.1.conv1 and prev_name: block1.layer.0.conv2
Name: block1.layer.1.bn2 and prev_name: block1.layer.1.conv1
Name: 48 and prev_name: 48
in here
Name: block1.layer.1.conv2 and prev_name: block1.layer.1.conv1
Name: block1.layer.1.dropout and prev_name: block1.layer.1.conv2
Name: block1.layer.2.bn1 and prev_name: block1.layer.1.conv2
Name: 48 and prev_name: 12
Name: block1.layer.2.relu and prev_name: bl

In [26]:
densenet_bc_40_12_test_new_fuse_module

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kernel_size=(3, 

In [27]:
densenet_bc_40_12_test_smooth = copy.deepcopy(densenet_bc_40_12_test)
densenet_bc_40_12_test_smooth.to(device)

DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
        (bn2): Identity()
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (2): BottleneckBlock(
        (bn1): Bat

In [28]:
val_loss, acc = validate(densenet_bc_40_12_test_smooth, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 44.28it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.37%





In [29]:
print(f'Old TV original model: {compute_tv_loss_for_network(densenet_bc_40_12_test_smooth, lambda_tv=1.0).item()}')

Old TV original model: 428.705078125


In [30]:
input_tensor = torch.randn(1, 3, 32, 32).to(device)
permute_func = PermutationManager(densenet_bc_40_12_test_smooth, input_tensor)

In [31]:
permute_dict = permute_func.compute_permute_dict()

In [32]:
densenet_bc_40_12_smooth = permute_func.apply_permutations(permute_dict, ignored_keys=[
    ('conv1.weight', 'in_channels'),
    ('fc.weight', 'out_channels'),
    ('fc.bias', 'out_channels')
])

In [33]:
print(f'Old TV original model: {compute_tv_loss_for_network(densenet_bc_40_12_smooth, lambda_tv=1.0).item()}')

Old TV original model: 395.0920715332031


In [34]:
val_loss, acc = validate(densenet_bc_40_12_smooth, val_loader, criterion, device=device)
print(f"Test on DenseNet-BC-40-12, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc:.2f}%")

100%|██████████| 157/157 [00:03<00:00, 44.24it/s]

Test on DenseNet-BC-40-12, Validation Loss: 0.2863, Validation Accuracy: 93.37%





> Prove that the model can still worked well even after it has been smoothed

In [35]:
sum(p.numel() for p in densenet_bc_40_12_smooth.parameters())

175258

In [36]:
for name, child in list(densenet_bc_40_12.named_children()):
    print(name, child)

conv1 Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
block1 DenseBlock(
  (layer): Sequential(
    (0): BottleneckBlock(
      (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (1): BottleneckBlock(
      (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias

In [37]:
for name, child in list(densenet_bc_40_12_test.named_children()):
    print(name, child)

conv1 Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
block1 DenseBlock(
  (layer): Sequential(
    (0): BottleneckBlock(
      (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1))
      (bn2): Identity()
      (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (1): BottleneckBlock(
      (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1))
      (bn2): Identity()
      (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (2): BottleneckBlock(
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_ru

In [38]:
def get_all_layers(module, prefix=''):
    layers = {}
    for name, child in module.named_children():
        full_name = f'{prefix}.{name}' if prefix else name
        if len(list(child.children())) > 0:
            # Get the children of a nested layer
            layers.update(get_all_layers(child, prefix=full_name))
        else:
            layers[full_name] = child
    return layers


In [39]:
for k, m in get_all_layers(densenet_bc_40_12).items():
    print(k, m)
    if isinstance(m, nn.BatchNorm2d):
        densenet_bc_40_12._modules[k] = nn.Identity()

conv1 Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
block1.layer.0.bn1 BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
block1.layer.0.relu ReLU(inplace=True)
block1.layer.0.conv1 Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
block1.layer.0.bn2 BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
block1.layer.0.conv2 Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
block1.layer.0.dropout Dropout(p=0.0, inplace=False)
block1.layer.1.bn1 BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
block1.layer.1.relu ReLU(inplace=True)
block1.layer.1.conv1 Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
block1.layer.1.bn2 BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
block1.layer.1.conv2 Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
block1.layer.1.dropout Dro

In [40]:
get_all_layers(densenet_bc_40_12)

{'conv1': Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 'block1.layer.0.bn1': Identity(),
 'block1.layer.0.relu': ReLU(inplace=True),
 'block1.layer.0.conv1': Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False),
 'block1.layer.0.bn2': Identity(),
 'block1.layer.0.conv2': Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 'block1.layer.0.dropout': Dropout(p=0.0, inplace=False),
 'block1.layer.1.bn1': Identity(),
 'block1.layer.1.relu': ReLU(inplace=True),
 'block1.layer.1.conv1': Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1), bias=False),
 'block1.layer.1.bn2': Identity(),
 'block1.layer.1.conv2': Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 'block1.layer.1.dropout': Dropout(p=0.0, inplace=False),
 'block1.layer.2.bn1': Identity(),
 'block1.layer.2.relu': ReLU(inplace=True),
 'block1.layer.2.conv1': Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False),
 'block1.layer.2.

In [41]:
densenet_bc_40_12.named_parameters

<bound method Module.named_parameters of DenseNet3(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): DenseBlock(
    (layer): Sequential(
      (0): BottleneckBlock(
        (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (1): BottleneckBlock(
        (bn1): BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv1): Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [42]:
get_all_layers(densenet_bc_40_12_test)

{'conv1': Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 'block1.layer.0.bn1': BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 'block1.layer.0.relu': ReLU(inplace=True),
 'block1.layer.0.conv1': Conv2d(24, 48, kernel_size=(1, 1), stride=(1, 1)),
 'block1.layer.0.bn2': Identity(),
 'block1.layer.0.conv2': Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 'block1.layer.0.dropout': Dropout(p=0.0, inplace=False),
 'block1.layer.1.bn1': BatchNorm2d(36, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 'block1.layer.1.relu': ReLU(inplace=True),
 'block1.layer.1.conv1': Conv2d(36, 48, kernel_size=(1, 1), stride=(1, 1)),
 'block1.layer.1.bn2': Identity(),
 'block1.layer.1.conv2': Conv2d(48, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 'block1.layer.1.dropout': Dropout(p=0.0, inplace=False),
 'block1.layer.2.bn1': BatchNorm2d(48, eps=1e-05, momentum=0.1, affine