In [1]:
import torch, timm, importlib

In [2]:
# ResNet-50 with a 100-way head (ImageNet-100); weights come from the fine-tuned
# checkpoint loaded in the next cell, so no pretrained download is requested.
m = timm.create_model('resnet50', pretrained=False, num_classes=100)

In [3]:
CHECKPOINT_PATH = 'output/train/20230914-091722-resnet50-224/model_best.pth.tar'
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only
# machine (validation in this notebook runs on CPU).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu')['state_dict']

In [4]:
# strict=True (the default, spelled out) raises on any missing/unexpected key
# instead of silently loading a partial model.
m.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [5]:
# Switch to inference mode (BatchNorm uses running statistics); train(False)
# is what eval() calls and it likewise returns the module for display.
m.train(False)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): ConvBN(
        (conv): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): ConvBN(
        (conv): Conv2d(512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (skip): Identity()
    )
    (1): Bot

In [6]:
# fuse ConvBN
# Fold each BatchNorm into its preceding conv via fuse_bn(); apply_transform
# below edits the resulting `fused_weight` / `fused_bias` tensors.
for layer in [m.layer1, m.layer2, m.layer3, m.layer4]:
    print("layer >>")
    for block in layer:
        # Fuse the weights in conv1 and conv3
        block.conv1.fuse_bn()
        print(block.conv1.fused_weight.size())
        block.conv3.fuse_bn()
        print(block.conv3.fused_weight.size())
        # A non-None act3 is treated as "this block's skip is a ConvBN projection"
        # (in the recorded run only layer[0] of layer2-4 takes this branch).
        # Fuse *this* block's skip rather than always layer[0]'s.
        if block.act3 is not None:
            block.skip.fuse_bn()
            print(block.skip.fused_weight.size())

layer >>
torch.Size([512, 64, 1, 1])
torch.Size([64, 512, 1, 1])
torch.Size([512, 64, 1, 1])
torch.Size([64, 512, 1, 1])
torch.Size([512, 64, 1, 1])
torch.Size([64, 512, 1, 1])
layer >>
torch.Size([512, 64, 1, 1])
torch.Size([128, 512, 1, 1])
torch.Size([128, 64, 1, 1])
torch.Size([1024, 128, 1, 1])
torch.Size([128, 1024, 1, 1])
torch.Size([1024, 128, 1, 1])
torch.Size([128, 1024, 1, 1])
torch.Size([1024, 128, 1, 1])
torch.Size([128, 1024, 1, 1])
layer >>
torch.Size([1024, 128, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([256, 128, 1, 1])
torch.Size([1024, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([1024, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([1024, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([1024, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
torch.Size([1024, 256, 1, 1])
torch.Size([256, 1024, 1, 1])
layer >>
torch.Size([1024, 256, 1, 1])
torch.Size([512, 1024, 1, 1])
torch.Size([512, 256, 1, 1])
torch.Size([2048, 512, 1, 1])
torch.Size([512, 2048, 

In [7]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

def validate(model, data_dir, batch_size=32, device='cpu', num_workers=4):
    """Evaluate ``model`` on an ImageFolder-style validation directory.

    Images have their shorter side resized to 256, are center-cropped to 224,
    converted to tensors and normalized with the ImageNet mean/std.

    Args:
        model: classifier whose output is a ``(batch, num_classes)`` logit
            tensor (as required by CrossEntropyLoss and argmax over dim 1).
            It is put in eval mode and moved to ``device`` (mutates the model).
        data_dir: root directory in the layout ``torchvision.datasets.ImageFolder``
            expects (one sub-directory per class).
        batch_size: images per forward pass.
        device: device to evaluate on; defaults to CPU.
        num_workers: DataLoader worker processes.

    Returns:
        ``(avg_loss, accuracy)``: mean cross-entropy per sample and top-1
        accuracy in percent.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    model.to(device)

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    val_dataset = datasets.ImageFolder(root=data_dir, transform=transform)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    correct = 0
    total = 0
    loss_sum = 0.0

    # Sum (not mean) per batch so the final average is per sample; averaging
    # per-batch means would over-weight a smaller last batch.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

            loss_sum += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError(f'No validation samples found in {data_dir!r}')

    accuracy = 100 * correct / total
    avg_loss = loss_sum / total

    print(f'Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

    return avg_loss, accuracy

In [8]:
# Baseline metrics before the channel rotation is applied.
validate(m, data_dir='../datasets/imagenet100/val/')

Validation Loss: 0.6637, Accuracy: 85.02%


(0.6637252083982143, 85.02)

In [9]:
# Fixed probe input for checking that apply_transform preserves the function.
# A dedicated seeded generator makes it reproducible without consuming the
# global RNG that torch.nn.init.orthogonal_ draws from later.
probe_gen = torch.Generator().manual_seed(0)
inpt = torch.randn(1, 3, 224, 224, generator=probe_gen)
# Reference logits only; no autograd graph is needed.
with torch.no_grad():
    out_old = m(inpt)

In [10]:
def apply_transform(block1, block2, Q, keep_identity=True):
    """Rotate the channels passed from ``block1`` to ``block2`` by ``Q``.

    block1's output side (fused ``conv3`` weight/bias and, when ``act3`` is set,
    its fused skip weight/bias) is left-multiplied by ``Q``; block2's input side
    (fused ``conv1`` weight and, when ``act3`` is set, its fused skip weight) is
    right-multiplied by ``Q^{-1}``, so the rotation cancels where block2 reads
    it. Tensors are edited in place.

    Args:
        block1: upstream block exposing ``conv3.conv.out_channels``,
            ``conv3.fused_weight``, ``conv3.fused_bias`` (set by ``fuse_bn()``),
            ``skip`` and ``act3``.
        block2: downstream block exposing ``conv1.conv.in_channels``,
            ``conv1.fused_weight``, ``skip`` and ``act3``.
        Q: invertible ``(n, n)`` matrix, ``n`` being the channel count between
            the blocks (a random orthogonal matrix in this notebook).
        keep_identity: if True, an ``nn.Identity`` skip is replaced by a 1x1
            bias-free ``Conv2d`` holding ``Q`` (block1) or ``Q^{-1}`` (block2).
            If False, identity skips are left as-is, as in the chained calls
            that reuse one ``Q`` across consecutive blocks.

    Raises:
        AssertionError: if ``Q`` is not square or ``n`` does not match the
            channels between the blocks.
    """
    with torch.no_grad():
        assert Q.size()[0] == Q.size()[1], "Q needs to be a square matrix"
        n = Q.size()[0]
        # The rotated channels are block1's outputs and block2's inputs.
        assert block1.conv3.conv.out_channels == n and block2.conv1.conv.in_channels == n, "Mismatched channels between blocks"

        Q_inv = torch.linalg.inv(Q)

        # block1: rotate output channels, W -> Q @ W, b -> Q @ b.
        block1.conv3.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.conv3.fused_weight.data)
        block1.conv3.fused_bias.data = torch.einsum('ij,j->i', Q, block1.conv3.fused_bias.data)

        if keep_identity and isinstance(block1.skip, torch.nn.Identity):
            block1.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)
            # Clone so the new weight does not share storage with the caller's Q.
            block1.skip.weight.data = Q.clone().unsqueeze(-1).unsqueeze(-1)
        # A non-None act3 is treated as "skip is a fused ConvBN projection".
        if block1.act3 is not None:
            block1.skip.fused_weight.data = torch.einsum('ij,jklm->iklm', Q, block1.skip.fused_weight.data)
            block1.skip.fused_bias.data = torch.einsum('ij,j->i', Q, block1.skip.fused_bias.data)

        # block2: undo the rotation on input channels, W -> W @ Q^{-1}
        # (biases are unaffected by an input-side change of basis).
        block2.conv1.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.conv1.fused_weight.data)

        if keep_identity and isinstance(block2.skip, torch.nn.Identity):
            block2.skip = torch.nn.Conv2d(n, n, kernel_size=1, bias=False)
            block2.skip.weight.data = Q_inv.unsqueeze(-1).unsqueeze(-1)
        if block2.act3 is not None:
            block2.skip.fused_weight.data = torch.einsum('ki,jklm->jilm', Q_inv, block2.skip.fused_weight.data)


In [11]:
# Rotate the residual stream running through layer3 and into layer4[0] by one
# random orthogonal Q. Every consecutive pair shares Q, so identity skips can
# stay as they are (keep_identity=False).
residual_blocks = list(m.layer3) + [m.layer4[0]]
n_channels = m.layer3[0].conv3.conv.out_channels  # 256 for this model
Q = torch.nn.init.orthogonal_(torch.empty(n_channels, n_channels))
for upstream, downstream in zip(residual_blocks, residual_blocks[1:]):
    apply_transform(upstream, downstream, Q, False)

In [13]:
# Re-validate after the rotation; metrics should match the baseline run.
validate(model=m, data_dir='../datasets/imagenet100/val/')

Validation Loss: 0.6637, Accuracy: 85.02%


(0.66372525997428, 85.02)

In [12]:
# Same probe input as before the rotation; an orthogonal change of basis should
# only introduce floating-point round-off (the recorded run saw ~1.5e-5).
with torch.no_grad():
    out_new = m(inpt)
max_abs_diff = (out_new - out_old).abs().max().item()
print(max_abs_diff)
assert max_abs_diff < 1e-3, f'transform changed model outputs (max |diff| = {max_abs_diff})'

1.537799835205078e-05
