### Small example illustrating:  
1. the equivalence between 1x1 convolutions and fully connected layers  
2. the application of the QR decomposition to split inputs (i.e. features) into class-specific and non-class-specific components, in cases where the 1x1 convolution is the final layer, e.g. of a segmentation network.

In [2]:
import torch
from torch import nn

In [35]:
class Network(nn.Module):
    """Single-layer 1x1 conv network plus an equivalent fully connected layer.

    A 1x1 conv weight of shape (out, in, 1, 1) holds the same parameters as
    an ``nn.Linear`` weight of shape (out, in). Call :meth:`params_conv2fc`
    to copy the conv parameters into the fc layer. After that, both
    ``forward`` modes produce the same output.

    Args:
        in_channels: number of input channels (= fc input features).
        out_channels: number of output channels (= fc output features).
    """

    def __init__(self, in_channels=20, out_channels=5):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )

        self.fc = nn.Linear(
            in_features=in_channels,
            out_features=out_channels,
        )

    def params_conv2fc(self):
        """Copy the conv weight and bias into the fc layer.

        ``flatten(1)`` collapses only the trailing (1, 1) kernel dims. This
        keeps the copy correct when ``in_channels`` or ``out_channels`` is 1,
        where a bare ``squeeze()`` would also drop a channel dim. The fc
        parameters are independent copies, so re-run this after changing the
        conv parameters.
        """
        with torch.no_grad():
            self.fc.weight.copy_(self.conv.weight.flatten(1))
            self.fc.bias.copy_(self.conv.bias)

    def forward(self, x, mode='conv'):
        """Apply either the 1x1 conv or the equivalent fc layer.

        Args:
            x: input shaped like a ``Conv2d`` input, (B, C, H, W) or
                unbatched (C, H, W), with C == in_channels.
            mode: ``'conv'`` to use the convolution; ``'fc'`` to apply the
                linear layer over the channel dimension.

        Returns:
            Tensor with the same layout as ``x``, where C is replaced by
            out_channels.

        Raises:
            ValueError: if ``mode`` is neither ``'conv'`` nor ``'fc'``.
        """
        if mode == 'conv':
            return self.conv(x)
        if mode == 'fc':
            # Channels sit third from the end for both batched and unbatched
            # Conv2d inputs; move them last so nn.Linear acts on them.
            x = x.movedim(-3, -1)
            x = self.fc(x)
            # Move channels back so the layout matches the conv output.
            return x.movedim(-1, -3)
        raise ValueError(f"mode must be 'conv' or 'fc', got {mode!r}")


In [40]:
### - Verify that a 1x1 conv is interchangeable with an fc layer
# Fix the seed so the randomly initialised weights and the test input are
# reproducible across runs.
torch.manual_seed(0)
# init model and test input
model = Network()
# Use a random input rather than all ones: with a constant input each
# output channel is just the sum of its weights plus bias, which is a much
# weaker check of the equivalence.
#                   B,  C, H, W
x_in = torch.randn((4, 20, 2, 2))
# match functions by copying conv params to fc layer
model.params_conv2fc()
# Conv and matmul kernels may accumulate in a different order, so allow a
# small absolute tolerance on top of allclose's default relative one.
torch.allclose(model(x_in, mode='conv'), model(x_in, mode='fc'), atol=1e-6)

True

In [85]:
### Extract weights and perform QR decomposition
# The conv weight (out, in, 1, 1) and bias (out,) are stacked into one
# augmented matrix, so the bias acts as the weight of a constant extra
# input channel. Feature vectors therefore need a trailing 1 appended
# before they are multiplied with W or Q. Here the feature vectors have
# shape (N, C); adapt the reshaping if yours are laid out differently.

n_classes = model.conv.out_channels

# shape (C + 1, n_classes): first C rows are conv weights, last row is bias.
# flatten(1) (not squeeze) keeps the class dim even when n_classes == 1.
W = torch.cat(
    [
        model.conv.weight.detach().flatten(1),
        model.conv.bias.detach().unsqueeze(1),
    ],
    dim=1,
).T
# Complete QR: Q is (C + 1, C + 1) orthogonal and R is (C + 1, n_classes)
# upper triangular, so the rows R[n_classes:] are all zero.
Q, R = torch.linalg.qr(W, mode='complete')
# init input features (random, so rows actually differ)
n_samples = 100
x = torch.randn((n_samples, model.conv.in_channels))
# add bias term
x = torch.cat([x, torch.ones((n_samples, 1))], dim=1)
print(f"Input feature shape (with bias):  {x.shape}")
# Coordinates in the span of Q[:, :n_classes]. That span contains the
# column space of W, and equals it when W has full column rank. These are
# the class-specific features.
truncated_features = x @ Q[:, :n_classes]
# Coordinates in the orthogonal complement: the non-class-specific
# features. The appended bias channel is part of x, so these coordinates
# mix the constant 1 with the real features.
remaining_features = x @ Q[:, n_classes:]
print(f"Shape of class specific features: {truncated_features.shape}")
print(f"Shape of non-class specific features: {remaining_features.shape}")
# Sanity check: because R[n_classes:] == 0, the class scores x @ W depend
# only on the class-specific features.
assert torch.allclose(x @ W, truncated_features @ R[:n_classes], atol=1e-4)

Input feature shape (with bias):  torch.Size([100, 21])
Shape of class specific features: torch.Size([100, 5])
Shape of non-class specific features: torch.Size([100, 16])
