In [13]:
import torch
from torch import nn
from torchvision import models

In [14]:
class HookWrapper(nn.Module):
    '''
        Wrapper class for Forward/Backward feature map extraction.

        - Usage -
        1) Make the instance of this class with the model and target layers.
        2) Forward/Backward it.
        3) Call get_features() will return the feature maps of previously forward/backwarded input.
        4) Back to 2).
    '''
    def __init__(self, model, target_layers):
        super(HookWrapper,self).__init__()
        self.model = model
        self.target_layers = target_layers
        self.features = [] # Size will be 4 after hook

        for name, module in model.named_children():
            if name in target_layers:
                module.register_forward_hook(self._extraction_fn)

    def _extraction_fn(self, module, input, output):
        # print(f"Test : {output.shape}")
        self.features.append(output)

    def forward(self, x):
        return self.model(x)
    
    def get_features(self): # Return feature list and make it empty.
        tmp = self.features
        self.features = []
        return tmp

# target_layers = ['conv1', 'layer1', 'layer2', 'layer3', 'layer4']
target_layers = ['conv1']
net_t = models.resnet18(pretrained=True)
net_s = models.resnet18(pretrained=True)
hook_net_t = HookWrapper(net_t, target_layers)
hook_net_s = HookWrapper(net_s, target_layers)

In [21]:
x_HR = torch.randn((1,3,224,224))
x_LR = torch.randn((1,3,112,112))
y = torch.tensor([900])

pred_t = hook_net_t(x_HR)
pred_s = hook_net_s(x_LR)
print(pred_s.shape)

features_t = hook_net_t.get_features()
features_s = hook_net_s.get_features()

downsample = nn.MaxPool2d(kernel_size=2)
criterion_a = nn.MSELoss()
criterion_b = nn.CrossEntropyLoss()

beta = 0.1
attention_loss = 0

for f_t, f_s in zip(features_t, features_s):
    f_t = downsample(f_t) # Transfer function of f_t -> f_s domain.
    if(f_t.shape != f_s.shape):
        continue
    print(f_t.shape, f_s.shape)

    attention_loss += criterion_a(f_t, f_s) # Activation-based Attention Transfer (Zagoryuko et al.)
    
classification_loss = criterion_b(pred_s, y)
total_loss = classification_loss + (beta/2)*attention_loss

print(total_loss)

torch.Size([1, 1000])
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
tensor(8.6964, grad_fn=<AddBackward0>)
